import torch
import math
from torch.utils.data import DataLoader
import myloss
import model
import dataset
import test
import config
import time
from myconfig import *

# weight_path = "F:\\study\\code\\mycode_3-22_with_relu\\weights\\"
# weight_name = "weight26_0.0000543_0.0216863_end_0.0000598_0.0008944.pth"
# dataset_path = "F:\\study\\mydataset"

# dataset_path = "G:\\0A\output\\data-input\\output\\val"
# Evaluation batch size; evaluate() drops any trailing batch smaller than this.
batch_size = 10
# DataLoader worker processes (0 = load data in the main process).
num_workers = 0


# Module-level metric accumulators, updated in place by cal_one_iter()/cal_one_bach()
# and reported at the end of evaluate().
false_negatives = 0
true_positives = 0
true_negatives = 0  # NOTE(review): never incremented anywhere in this file; always prints 0
false_positives = 0
position_errors = []   # squared distance of each matched prediction/ground-truth pair
direction_errors = []  # absolute direction difference (rad) of each matched pair

# Progress bookkeeping set inside evaluate(); image_index advances once per batch.
image_index = 0
image_name = ""

def direction_diff(direction_a, direction_b):
    """Return the absolute angular difference between two directions (rad).

    The raw difference is wrapped so the result is the shorter arc
    between the two angles (at most pi for normalized inputs).
    """
    raw = abs(direction_a - direction_b)
    return min(raw, 2 * math.pi - raw)


def calc_point_squre_dist(point_a, point_b):
    """Return the squared Euclidean distance between two marking points."""
    dx = point_a.x - point_b.x
    dy = point_a.y - point_b.y
    return dx * dx + dy * dy


def calc_point_direction_angle(point_a, point_b):
    """Return the absolute angle (rad) between the two points' directions."""
    angle = direction_diff(point_a.direction, point_b.direction)
    return angle


def match_marking_points(point_a, point_b):
    """Determine whether a detected point matches a ground-truth point.

    Matching currently uses only the squared position distance; the
    shape and direction-angle criteria are deliberately disabled.
    """
    dist_square = calc_point_squre_dist(point_a, point_b)
    # if point_a.shape > 0.5 and point_b.shape < 0.5:
    #     return False
    # if point_a.shape < 0.5 and point_b.shape > 0.5:
    #     return False
    # NOTE(review): the original also computed the direction angle here but
    # never used it (the `and angle < config.DIRECTION_ANGLE_THRESH` clause is
    # commented out); the dead computation has been removed. Re-enable via
    # calc_point_direction_angle(point_a, point_b) if the angle check returns.
    return dist_square < config.SQUARED_DISTANCE_THRESH


def match_gt_with_preds(query, candidates):
    """Return the index of the best match for *query* among *candidates*.

    *query* and every candidate are (confidence, marking_point) pairs.
    Among the candidates whose point matches the query's point (see
    match_marking_points), the one with the highest confidence wins.
    Returns -1 when nothing matches.

    NOTE(review): despite the function name, the in-file caller passes a
    *prediction* as `query` and the ground-truth list as `candidates`.
    (Parameters renamed from `intput`/`list` — a typo and a builtin
    shadow; all in-file calls are positional.)
    """
    max_confidence = 0.
    matched_idx = -1
    for idx, candidate in enumerate(candidates):
        if match_marking_points(query[1], candidate[1]) and max_confidence < candidate[0]:
            max_confidence = candidate[0]
            matched_idx = idx
    return matched_idx


def cal_one_bach(gt_heatamp_pos, gt_heatamp_angle, per_heatamp_pos, pre_heatamp_angle):
    """Score one sample: return (tp, fp, fn) and record matched-pair errors.

    Decodes marking points from the ground-truth and predicted heatmaps,
    matches each prediction against the ground truth, and appends the
    position / direction errors of every matched pair to the module-level
    error lists.
    """
    ground_truths = test.get_angle(test.get_position(gt_heatamp_pos), gt_heatamp_angle)
    predictions = test.get_angle(test.get_position(per_heatamp_pos), pre_heatamp_angle)
    true_pos = 0
    false_pos = 0
    for prediction in predictions:
        matched = match_gt_with_preds(prediction, ground_truths)
        if matched < 0:
            false_pos += 1
            continue
        true_pos += 1
        position_errors.append(calc_point_squre_dist(prediction[1], ground_truths[matched][1]))
        direction_errors.append(calc_point_direction_angle(prediction[1], ground_truths[matched][1]))
    # NOTE(review): several predictions can match the same ground truth, which
    # would inflate tp and can drive fn below zero — confirm whether a
    # one-to-one assignment was intended.
    false_neg = len(ground_truths) - true_pos
    return true_pos, false_pos, false_neg


def cal_one_iter(gt_heatamp_poss, gt_heatamp_angles, per_heatamp_poss, pre_heatamp_angles):
    """Score a full batch and add its counts to the module-level totals.

    Iterates over the `batch_size` samples of the batch, scores each one
    with cal_one_bach, and accumulates tp/fp/fn into the global counters.
    """
    global true_positives, false_positives, false_negatives
    for sample_idx in range(batch_size):
        tp, fp, fn = cal_one_bach(
            gt_heatamp_poss[sample_idx],
            gt_heatamp_angles[sample_idx],
            per_heatamp_poss[sample_idx],
            pre_heatamp_angles[sample_idx],
        )
        true_positives += tp
        false_positives += fp
        false_negatives += fn



def evaluate():
    """Run the trained PSNet over the dataset and print evaluation metrics.

    Loads the checkpoint named by `weight_path`/`weight_name` (from the
    myconfig star import), scores every full batch, then prints the raw
    counts followed by precision, recall, and the average position /
    direction errors over all matched detections. Requires CUDA.
    """
    net = model.PSNet()
    net = torch.nn.DataParallel(net)
    net = net.cuda()
    device = torch.device('cuda:0')
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    net.module.load_state_dict(
        torch.load(weight_path + weight_name, map_location='cuda:0'),
        strict=False)
    net.eval()
    data = dataset.ParkingSlotDataset(dataset_path)
    data_loader = DataLoader(data,
                             batch_size=batch_size, shuffle=False,
                             num_workers=num_workers,
                             collate_fn=lambda x: list(zip(*x)))
    num_iter = len(data_loader)
    with torch.no_grad():
        for iter_idx, (images, target_pos, target_angle, marking_points) in enumerate(data_loader):
            # Drop the trailing ragged batch: cal_one_iter assumes exactly
            # batch_size samples per batch.
            if len(images) < batch_size:
                break
            global image_index, image_name
            # NOTE(review): image_index advances once per *batch*, so this
            # indexes sample_names by batch number, not by sample — confirm.
            image_name = data.sample_names[image_index]
            images = torch.stack(images).to(device)
            start = time.time()
            output_pos, output_angle = net(images)

            target_pos = torch.stack(target_pos).to(device)
            target_angle = torch.stack(target_angle).to(device)
            cal_one_iter(target_pos, target_angle, output_pos, output_angle)
            image_index = image_index + 1
            end = time.time()
            print(end - start, "***********")
            print(iter_idx, num_iter)
    print(false_negatives, true_positives, true_negatives, false_positives)
    # Guard every denominator so an empty/unmatched run reports 0.0 instead of
    # raising ZeroDivisionError (the original crashed in that case).
    num_detections = true_positives + false_positives
    num_ground_truths = true_positives + false_negatives
    precision = true_positives / num_detections if num_detections else 0.0
    recall = true_positives / num_ground_truths if num_ground_truths else 0.0
    num_matched = len(position_errors)
    aver_pos_error = sum(position_errors) / num_matched if num_matched else 0.0
    aver_angle_error = sum(direction_errors) / num_matched if num_matched else 0.0

    print(precision, recall, aver_pos_error, aver_angle_error)


# Script entry point: run the full evaluation when invoked directly.
if __name__ == '__main__':
    evaluate()

